#include "AnisotropicPointsToImplicit3.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <vector>

#include "../Simulation/SphKernels3.h"
#include "../Simulation/SphSystemData3.h"
#include "../Simulation/PointParallelHashGridSearcher3.h"
#include "../Math/SingularValueDecomposition.h"

namespace Engine
{
	// Radial kernel profile used by w(): (1 - d^2)^3 while d^2 < 1, and 0 at or
	// beyond unit distance.
	inline double p(double distance) 
	{
		const double oneMinusSq = 1.0 - distance * distance;
		if (oneMinusSq <= 0.0)
			return 0.0;   // |distance| >= 1: outside the kernel support
		return oneMinusSq * oneMinusSq * oneMinusSq;
	}

	// Isotropic neighbor weight used for the mean / covariance estimates:
	// 1 - cubic(distance / r) strictly inside radius r, and 0 otherwise
	// (including distance == r).
	inline double wij(double distance, double r) 
	{
		return (distance < r) ? 1.0 - cubic(distance / r) : 0.0;
	}

	// Outer product v * v^T, passed to the Matrix3x3D constructor row by row.
	// The matrix is symmetric, so each off-diagonal product is computed once.
	inline Matrix3x3D vvt(const Vector3D& v) 
	{
		const double xy = v.x * v.y;
		const double xz = v.x * v.z;
		const double yz = v.y * v.z;

		return Matrix3x3D(
			v.x * v.x, xy,        xz,
			xy,        v.y * v.y, yz,
			xz,        yz,        v.z * v.z);
	}

	// Anisotropic kernel value: 315 / (64 * pi) * gDet * p(|g * r|).
	// g transforms the offset r before the radial profile p() is applied.
	// gDet is taken as a separate argument; in this file callers pass the
	// determinant of g.
	inline double w(const Vector3D& r, const Matrix3x3D& g, double gDet) 
	{
		static const double normalization = 315.0 / (64 * kPiD);
		const double transformedDistance = (g * r).length();
		return normalization * gDet * p(transformedDistance);
	}

	// Stores the conversion parameters as given; no validation happens here
	// (convert() rejects a non-positive kernel radius at call time).
	AnisotropicPointsToImplicit3::AnisotropicPointsToImplicit3(
		double kernelRadius,
		double cutOffDensity,
		double positionSmoothingFactor,
		size_t minNumNeighbors,
		bool isOutputSdf)
		: _kernelRadius(kernelRadius)
		, _cutOffDensity(cutOffDensity)
		, _positionSmoothingFactor(positionSmoothingFactor)
		, _minNumNeighbors(minNumNeighbors)
		, _isOutputSdf(isOutputSdf)
	{
	}

	void AnisotropicPointsToImplicit3::convert(const ConstArrayAccessor1<Vector3D>& points, ScalarGrid3 * output) const
	{
		if (output == nullptr) 
		{
			std::cerr << "Null scalar grid output pointer provided.\n";
			return;
		}

		const auto res = output->resolution();
		if (res.x * res.y * res.z == 0)
		{
			std::cout << "Empty grid is provided.\n";
			return;
		}

		const auto bbox = output->boundingBox();
		if (bbox.isEmpty()) 
		{
			std::cout << "Empty domain is provided.\n";
			return;
		}

		std::cout << "Start converting points to implicit surface.\n";

		const double h = _kernelRadius;
		const double invH = 1 / h;
		const double r = 2.0 * h;

		// Mean estimator for cov. mat.
		//const auto meanNeighborSearcher = PointKdTreeSearcher3::builder().makeShared();
		//meanNeighborSearcher->build(points);

		std::cout << "Built neighbor searcher.";

		SphSystemData3 meanParticles;
		meanParticles.addParticles(points);
		meanParticles.setKernelRadius(r);
		meanParticles.buildNeighborSearcher();
		//meanParticles.setNeighborSearcher(meanNeighborSearcher);
		const auto meanNeighborSearcher = meanParticles.neighborSearcher();

		// Compute G and xMean
		std::vector<Matrix3x3D> gs(points.size());
		Array1<Vector3D> xMeans(points.size());

		parallelFor(kZeroSize, points.size(), [&](size_t i) 
		{
			const auto& x = points[i];
			// Compute xMean
			Vector3D xMean;
			double wSum = 0.0;
			size_t numNeighbors = 0;
			const auto getXMean = [&](size_t, const Vector3D& xj)
			{
				const double wj = wij((x - xj).length(), r);
				wSum += wj;
				xMean += wj * xj;
				++numNeighbors;
			};
			meanNeighborSearcher->forEachNearbyPoint(x, r, getXMean);

			assert(wSum > 0.0);
			xMean /= wSum;

			xMeans[i] = lerp(x, xMean, _positionSmoothingFactor);

			if (numNeighbors < _minNumNeighbors) 
			{
				const auto g = Matrix3x3D::makeScaleMatrix(invH, invH, invH);
				gs[i] = g;
			}
			else 
			{
				// Compute covariance matrix
				// We start with small scale matrix (h*h) in order to
				// prevent zero covariance matrix when points are all
				// perfectly lined up.
				auto cov = Matrix3x3D::makeScaleMatrix(h * h, h * h, h * h);
				wSum = 0.0;
				const auto getCov = [&](size_t, const Vector3D& xj) 
				{
					const double wj = wij((xMean - xj).length(), r);
					wSum += wj;
					cov += wj * vvt(xj - xMean);
				};
				meanNeighborSearcher->forEachNearbyPoint(x, r, getCov);

				cov /= wSum;

				// SVD
				Matrix3x3D u;
				Vector3D v;
				Matrix3x3D w;
				svd(cov, u, v, w);

				// Take off the sign
				v.x = std::fabs(v.x);
				v.y = std::fabs(v.y);
				v.z = std::fabs(v.z);

				// Constrain Sigma
				const double maxSingularVal = v.max();
				const double kr = 4.0;
				v.x = std::max(v.x, maxSingularVal / kr);
				v.y = std::max(v.y, maxSingularVal / kr);
				v.z = std::max(v.z, maxSingularVal / kr);

				const auto invSigma = Matrix3x3D::makeScaleMatrix(1.0 / v);

				// Compute G
				const double scale =
					std::pow(v.x * v.y * v.z, 1.0 / 3.0);  // volume preservation
				const Matrix3x3D g = invH * scale * (w * invSigma * u.transposed());
				gs[i] = g;
			}
		});

		std::cout << "Computed G and means.\n";

		// SPH estimator
		meanParticles.setKernelRadius(h);
		meanParticles.updateDensities();
		const auto d = meanParticles.densities();
		const double m = meanParticles.mass();

		SphSystemData3 meanParticles2;
		meanParticles2.addParticles(xMeans.constAccessor());
		meanParticles2.setKernelRadius(r);
		meanParticles2.buildNeighborSearcher();
		const auto meanNeighborSearcher2 = meanParticles2.neighborSearcher();

		// Compute SDF
		auto temp = output->clone();
		temp->fill([&](const Vector3D& x) 
		{
			double sum = 0.0;
			meanNeighborSearcher2->forEachNearbyPoint(
				x, r, [&](size_t i, const Vector3D& neighborPosition) 
			{
				sum += m / d[i] * w(neighborPosition - x, gs[i], gs[i].determinant());
			});

			return _cutOffDensity - sum;
		});

		std::cout << "Computed SDF.\n";

		//if (_isOutputSdf) 
		//{
		//	FmmLevelSetSolver3 solver;
		//	solver.reinitialize(*temp, kMaxD, output);

		//	std::cout << "Completed einitialization.\n";
		//}
		//else 
		{
			temp->swap(output);
		}

		std::cout << "Done converting points to implicit surface.\n";
	}
}